Support Vector Machines for Classification

What are Support Vector Machines (SVM)?

Support vector machines are remarkable for unifying elegant geometry and mathematics with theoretical guarantees and strong practical performance.

History

The theory behind SVMs was developed by Vladimir Vapnik and colleagues, starting in 1963, with the modern formulation developed later at AT&T Bell Laboratories1. SVMs are among the most robust prediction methods, as they are based on the statistical learning framework called VC theory, developed by Vapnik and Chervonenkis.

Soft-margin SVMs are an example of an empirical risk minimization (ERM) algorithm with the hinge loss function. SVMs belong to a natural class of algorithms for statistical inference, which also happens to produce very good predictions.
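As a quick illustration (my own snippet, not from the ERM literature), the hinge loss is one line of R:

```r
# hinge loss: zero for points classified correctly with margin >= 1,
# growing linearly with the size of the violation otherwise
hinge <- function(y, f) pmax(0, 1 - y * f)   # y in {-1, 1}, f = w'x + b

hinge(c(1, 1, -1), c(2.0, 0.5, -0.2))
## [1] 0.0 0.5 0.8
```

Correctly classified points with margin at least 1 contribute nothing; violations are penalised linearly.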

Inference + Prediction = Data Science. What more can you ask?

Vapnik showing off his ERM framework, taking a jibe at Bayesian statistics. (ERM formula at top.)

How do they work?

Geometrically, SVM tries to find a hyperplane that linearly separates the data into two classes.

Consider the following example where I’m selecting two species of Iris flowers and plotting their sepal width and sepal length. The colour represents the species.

library(tidyverse)
theme_set(ggthemes::theme_clean())

p = iris |> 
   filter(Species != "versicolor") |> 
   ggplot(aes(x = Sepal.Length, y = Sepal.Width, colour = Species)) +
   geom_point()
p

There seems to be a clear separation between the two species. Can we draw a (straight) line that separates them?

p = p +
   geom_segment(aes(x = 4, y = 2, xend = 7, yend = 4.5), colour = 4, lty = 2, alpha = 0.7)
p

Except for one setosa point, which is misclassified, we got them all right.

However, there are infinitely many other lines possible.

p = p +
   geom_segment(aes(x = 4.4, y = 2, xend = 6.5, yend = 4.5), colour = 5, lty = 2, alpha = 0.7) +
   geom_segment(aes(x = 5, y = 2, xend = 6.5, yend = 4), colour = 6, lty = 2, alpha = 0.7)
p

And many, many more.

The Paradox of Choice

Since there are so many choices in deciding the best model, we need to define the problem more rigorously.

What would be the “best” hyperplane separating the two classes?

One way to frame this problem is to ask how we could maximize the distance between the two classes. The best separating hyperplane, then, is the one that maximizes this distance, known as the margin.

Think again of the classification problem we have.2

iris |> 
   filter(Species != "versicolor") |>
   filter(Sepal.Length > 5) |> 
   ggplot(aes(x = Sepal.Length, y = Sepal.Width, colour = Species)) +
   geom_point()

What is a good “margin”?

By visual examination, choose between the three options: which one is the best separating hyperplane?

Option A

Option B

Option C

Mathematically…

The middle line of the margin is \(w'x + b = 0\), while the top and bottom lines are \(w'x + b = 1\) and \(w'x + b = -1\).

For any unseen point, the predicted class is simply the side of the hyperplane it falls on:

\[ f(x) = \operatorname{sign}(w'x + b). \]

The margin width is \(\frac{2}{||w||}\), which has to be maximized. This is equivalent to minimizing \(\frac{||w||^2}{2}\), subject to the constraints:

\[ \begin{cases} w'x_i + b \geq 1 & \text{if} \quad y_i = 1 \\ w'x_i + b \leq -1 & \text{if} \quad y_i = -1 \end{cases} \]

for every training point \((x_i, y_i)\).

This is a constrained optimization problem that can be solved via many methods (numerical, quadratic optimization, etc.).
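As a sketch of what this looks like in practice, here is one way to recover the fitted hyperplane using the e1071 package (my choice for illustration; any QP-based SVM solver would do). A very large cost approximates the hard-margin problem:

```r
library(e1071)

# same two species as before; drop the unused factor level
d = droplevels(subset(iris, Species != "versicolor"))

# a large cost approximates the hard-margin problem
fit = svm(Species ~ Sepal.Length + Sepal.Width, data = d,
          kernel = "linear", cost = 1000, scale = FALSE)

# recover the hyperplane w'x + b = 0 from the support vectors
w = t(fit$coefs) %*% fit$SV
b = -fit$rho
c(w, b)
```

The coefficients `w` and intercept `b` define exactly the kind of separating line we drew by hand above.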

What if they’re not separable?

In our dummy example, I removed two points. But that is usually not a good idea. Can you exclude points from your data because they’re hard to classify?3

That’s a blunder for two reasons.

  1. First, we want to build a model that works for all data points — including extreme data points. We will not know if a test point is an extreme point.
  2. Second, how will you decide which points to remove? If you remove all the tough cases, why use SVM at all? A simple linear regression can do a reasonably good job on the easy points.

Let’s take a look at a problem when the classes are not perfectly separable.

This one “blue” point is being misclassified. Can we do something about it?

Here come slack variables to rescue…

Slack variables (\(\xi_i\)) add a “padding” around the margin that varies by observation. For data on the wrong side of the margin, the modified objective function incurs a penalty proportional to the point’s distance from the margin.

This is called “soft” margin.

Optimisation Problem

\[ \min L(w) = \frac{||w||^2}{2} + C\left( \sum_{i = 1}^N \xi_i^k \right) \]

subject to constraints

\[ \begin{cases} w'x_i + b \geq 1 - \xi_i & \text{if} \quad y_i = 1 \\ w'x_i + b \leq -1 + \xi_i & \text{if} \quad y_i = -1 \end{cases} \qquad \text{with } \xi_i \geq 0. \]
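To make the objective concrete, here is a minimal subgradient-descent sketch of the soft-margin problem in base R. This is purely illustrative (real solvers use quadratic programming), and the data are synthetic:

```r
# minimise ||w||^2 / 2 + C * sum(hinge) by subgradient descent
soft_margin_svm = function(X, y, C = 1, lr = 0.01, epochs = 2000) {
  w = rep(0, ncol(X)); b = 0
  for (e in 1:epochs) {
    margins = y * as.vector(X %*% w + b)
    viol = margins < 1                 # points violating the margin
    # subgradient of the objective w.r.t. w and b
    grad_w = w - C * colSums(X[viol, , drop = FALSE] * y[viol])
    grad_b =   - C * sum(y[viol])
    w = w - lr * grad_w
    b = b - lr * grad_b
  }
  list(w = w, b = b)
}

# two well-separated Gaussian clouds
set.seed(1)
X = rbind(matrix(rnorm(40, mean = -2), ncol = 2),
          matrix(rnorm(40, mean =  2), ncol = 2))
y = rep(c(-1, 1), each = 20)

fit = soft_margin_svm(X, y, C = 1)
mean(sign(as.vector(X %*% fit$w) + fit$b) == y)
```

Increasing `C` penalises margin violations more heavily, which is exactly the trade-off the slack variables encode.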

Another alternative: Non-linear SVM

What if the data follow a non-linear trend, like the example below? A linear decision boundary does not make sense at all in that case.

We can map our features into a new feature space where they are linearly separable. Recall that we usually take the natural logarithm of wealth before using it in a linear regression. The concept is similar, except that it is far more general and works in many cases.
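A toy illustration of such a mapping (my own example): one-dimensional points labelled by distance from the origin cannot be separated by a single threshold on \(x\), but become separable after mapping \(x \mapsto x^2\):

```r
x = c(-2, -1.5, -0.3, 0.1, 0.4, 1.4, 2.1)
y = ifelse(abs(x) > 1, "outer", "inner")

# no single threshold on x separates the classes,
# but in the mapped feature z = x^2 a threshold at 1 does
z = x^2
split(z, y)
```

All “inner” points land below 1 and all “outer” points above it, so a linear boundary in the new space classifies the data perfectly.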

Kernel Functions

We can also create non-linear classifiers by applying the “kernel trick”, a well-known technique in statistics that implicitly maps the data into a higher-dimensional feature space. It is often easier to find clear decision boundaries in higher dimensions.

The resulting algorithm is largely the same, except that every dot product is replaced by a non-linear kernel function. The algorithm can then fit the maximum-margin hyperplane in the transformed feature space.

Note that the transformation might be nonlinear and the new space can be high dimensional. The classifier will be a linear hyperplane in the new space but might be nonlinear in the original input space.

Some Common Kernels

Polynomial Kernel

\[ k(x_i, x_j) = (x_i' x_j + 1)^d, \]

when \(d = 1\), this is the linear kernel; when \(d = 2\), the quadratic kernel.
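For \(d = 2\) and two-dimensional inputs, this kernel is exactly a dot product in a six-dimensional feature space. A quick numeric check (`phi` below is the standard explicit expansion, written out for illustration):

```r
poly_kernel = function(x, y, d = 2) (sum(x * y) + 1)^d

# explicit feature map matching the d = 2 kernel on 2-D inputs
phi = function(x) c(1, sqrt(2) * x[1], sqrt(2) * x[2],
                    x[1]^2, sqrt(2) * x[1] * x[2], x[2]^2)

x = c(1, 2); y = c(3, -1)
poly_kernel(x, y)       # kernel evaluated directly -> 4
sum(phi(x) * phi(y))    # same value via the explicit map -> 4
```

The kernel computes this six-dimensional dot product without ever constructing the expanded features, which is the whole point of the trick.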

Radial Basis Kernel / Gaussian Kernel

\[ k(x_i, x_j) = \exp(-\gamma ||x_i - x_j||^2), \]

for \(\gamma > 0\). When \(\gamma = \frac{1}{2\sigma^2}\), the kernel is said to have width \(\sigma\). It is also known as the Radial Basis Function (RBF) kernel.
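The RBF kernel is one line of R. A quick sanity check (my own snippet): it equals 1 for identical points and decays towards 0 as points move apart:

```r
rbf_kernel = function(x, y, gamma = 1) exp(-gamma * sum((x - y)^2))

rbf_kernel(c(0, 0), c(0, 0))   # identical points: exactly 1
rbf_kernel(c(0, 0), c(3, 4))   # far-apart points: essentially 0
```

This locality is why the RBF kernel acts like a smooth similarity measure between observations.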

Characteristics of SVM

  1. SVMs perform well on average and can outperform many other techniques in important applications.
  2. The effectiveness of SVM in practice depends on (a) the choice of kernel, (b) kernel’s parameters, and (c) soft-margin parameter \(C\).
  3. Gaussian Kernel (or RBF) is a common choice for kernel function.
  4. Being a statistically-oriented method, the results are stable, reproducible and largely independent of the specific optimisation algorithm.
  5. Being a convex optimisation problem leads to the global optimum.
  6. Computational challenge: solving the optimisation problem has quadratic complexity in the number of observations. That alone is manageable, but kernel feature spaces can increase the number of features dramatically, magnifying the problem.

Example: Vanilla SVM

In this example, we will try to predict the type of an animal from its other characteristics using a linear SVM, a.k.a. vanilla SVM. This is the Zoo data from the mlbench package.

Let’s see the data.

data(Zoo, package = "mlbench")
Zoo = as_tibble(Zoo)
Zoo |>
   DT::datatable()

Correlation

Let’s do some descriptive statistics and explore how the data looks. How do different types of animals vary? Can we see a quick correlation?

library(ggcorrplot)
model.matrix(~ 0 + ., data = Zoo) |>  
  cor(use = "pairwise.complete.obs") |> 
  ggcorrplot(show.diag = FALSE, type = "lower", lab = TRUE, lab_size = 2)

This can tell us a lot of interesting insights!

Modelling

Fitting Model

But our time is limited, so let’s jump to the SVM. Recall that train() takes a formula as the first argument, the data as the second, the method as the third, and then other training controls.

library(caret)
svmFit = train(
   type ~ ., 
   data = Zoo, 
   method = "svmLinear",
   trControl = trainControl(method = "cv", number = 10)
)
svmFit
## Support Vector Machines with Linear Kernel 
## 
## 101 samples
##  16 predictor
##   7 classes: 'mammal', 'bird', 'reptile', 'fish', 'amphibian', 'insect', 'mollusc.et.al' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 92, 90, 91, 91, 91, 90, ... 
## Resampling results:
## 
##   Accuracy   Kappa    
##   0.9788889  0.9707841
## 
## Tuning parameter 'C' was held constant at a value of 1

Let’s see details about the final model.

# storing final model
svmFinal = svmFit$finalModel
svmFinal
## Support Vector Machine object of class "ksvm" 
## 
## SV type: C-svc  (classification) 
##  parameter : cost C = 1 
## 
## Linear (vanilla) kernel function. 
## 
## Number of Support Vectors : 47 
## 
## Objective Function Value : -0.1448 -0.218 -0.1484 -0.1754 -0.0936 -0.1033 -0.297 -0.0819 -0.1556 -0.0907 -0.1135 -0.182 -0.5763 -0.13 -0.1833 -0.118 -0.0474 -0.0823 -0.1236 -0.1481 -0.5666 
## Training error : 0

Predictions

# creating predictions
pred = predict(svmFit, newdata = Zoo)
pred
##   [1] mammal        mammal        fish          mammal        mammal       
##   [6] mammal        mammal        fish          fish          mammal       
##  [11] mammal        bird          fish          mollusc.et.al mollusc.et.al
##  [16] mollusc.et.al bird          mammal        fish          mammal       
##  [21] bird          bird          mammal        bird          insect       
##  [26] amphibian     amphibian     mammal        mammal        mammal       
##  [31] insect        mammal        mammal        bird          fish         
##  [36] mammal        mammal        bird          fish          insect       
##  [41] insect        bird          insect        bird          mammal       
##  [46] mammal        mollusc.et.al mammal        mammal        mammal       
##  [51] mammal        insect        amphibian     mollusc.et.al mammal       
##  [56] mammal        bird          bird          bird          bird         
##  [61] fish          fish          reptile       mammal        mammal       
##  [66] mammal        mammal        mammal        mammal        mammal       
##  [71] mammal        bird          mollusc.et.al fish          mammal       
##  [76] mammal        reptile       mollusc.et.al bird          bird         
##  [81] reptile       mollusc.et.al fish          bird          mammal       
##  [86] mollusc.et.al fish          bird          insect        amphibian    
##  [91] reptile       reptile       fish          mammal        mammal       
##  [96] bird          mammal        insect        mammal        mollusc.et.al
## [101] bird         
## Levels: mammal bird reptile fish amphibian insect mollusc.et.al

Confusion Matrix

I’m predicting on the training data here, which is not advisable in general, but it shows how the SVM function works.

# confusion matrix
table(Zoo$type, pred)
##                pred
##                 mammal bird reptile fish amphibian insect mollusc.et.al
##   mammal            41    0       0    0         0      0             0
##   bird               0   20       0    0         0      0             0
##   reptile            0    0       5    0         0      0             0
##   fish               0    0       0   13         0      0             0
##   amphibian          0    0       0    0         4      0             0
##   insect             0    0       0    0         0      8             0
##   mollusc.et.al      0    0       0    0         0      0            10

Accuracy

# prediction accuracy
sum(Zoo$type==pred)/nrow(Zoo)
## [1] 1
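For completeness, here is a hedged sketch of the proper workflow with a hold-out split. I use base R’s sample() and the e1071 package on iris so the snippet runs standalone; caret::createDataPartition() with the Zoo data would work the same way:

```r
library(e1071)

# hold out 30% of the data for an honest accuracy estimate
set.seed(42)
idx = sample(nrow(iris), size = 0.7 * nrow(iris))
train_set = iris[idx, ]
test_set  = iris[-idx, ]

fit  = svm(Species ~ ., data = train_set, kernel = "linear")
pred = predict(fit, newdata = test_set)

# accuracy on data the model has never seen
mean(pred == test_set$Species)
```

Test-set accuracy is almost always lower than the training accuracy we computed above, and it is the number you should report.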

  1. Now part of Nokia Bell Labs.↩︎

  2. You would notice that I’ve removed two points, just for simplicity.↩︎

  3. Try saying “these points are hard to classify so I’ll ignore them” to your client. Don’t. Unless you want to lose your job.↩︎